{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3781d0eb",
   "metadata": {},
   "source": [
    "- 智能客服\n",
    "    - 用户问题（一次 query）意图识别"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54868963",
   "metadata": {},
   "source": [
    "## pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fea8f79c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T04:22:22.471763Z",
     "start_time": "2023-06-18T04:22:20.422496Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from pathlib import Path\n",
    "from time import perf_counter\n",
    "import numpy as np\n",
    "from transformers import pipeline\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebb0d527",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-15T14:11:43.339954Z",
     "start_time": "2023-06-15T14:11:43.336422Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'\n",
    "os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17152673",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T04:23:53.980362Z",
     "start_time": "2023-06-18T04:23:52.530669Z"
    }
   },
   "outputs": [],
   "source": [
    "bert_ckpt = 'transformersbook/bert-base-uncased-finetuned-clinc'\n",
    "pipe = pipeline('text-classification', model=bert_ckpt, device=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b3017d5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T04:23:54.970047Z",
     "start_time": "2023-06-18T04:23:54.962881Z"
    }
   },
   "outputs": [],
   "source": [
    "next(iter(pipe.model.parameters())).device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cd63f71",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T04:23:56.829471Z",
     "start_time": "2023-06-18T04:23:56.024895Z"
    }
   },
   "outputs": [],
   "source": [
    "query = \"\"\"Hey, I'd like to rent a vehicle from Nov 1st to Nov 15th in Paris and I need a 15 passenger van\"\"\"\n",
    "pipe(query)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68b3eba3",
   "metadata": {},
   "source": [
    "### 关于 pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78d4dd85",
   "metadata": {},
   "source": [
    "- `pipe.model`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50285cdc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-13T16:01:09.096346Z",
     "start_time": "2023-06-13T16:01:09.086128Z"
    }
   },
   "outputs": [],
   "source": [
    "# classifier head: 151 分类\n",
    "pipe.model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07d27ca8",
   "metadata": {},
   "source": [
    "## 模型性能评估"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "092fb30c",
   "metadata": {},
   "source": [
    "- Model performance\n",
    "    - dataset accuracy\n",
    "- Latency\n",
    "    - query/inference time\n",
    "- Memory\n",
    "    - model size"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2befa626",
   "metadata": {},
   "source": [
    "### datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71e6207e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T04:24:07.910157Z",
     "start_time": "2023-06-18T04:24:07.015898Z"
    }
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48618921",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T04:24:11.312238Z",
     "start_time": "2023-06-18T04:24:08.598496Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "clinc = load_dataset(\"clinc_oos\", \"plus\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee7865c4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-15T14:00:25.460237Z",
     "start_time": "2023-06-15T14:00:25.451174Z"
    }
   },
   "outputs": [],
   "source": [
    "clinc['test'][42]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b42e1be",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-15T14:00:26.346877Z",
     "start_time": "2023-06-15T14:00:26.340808Z"
    }
   },
   "outputs": [],
   "source": [
    "# clinc['test'].features['intent']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f5785c1",
   "metadata": {},
   "source": [
    "### metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b8b1c8c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T04:24:35.511045Z",
     "start_time": "2023-06-18T04:24:35.505851Z"
    }
   },
   "outputs": [],
   "source": [
    "intents = clinc['test'].features['intent']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf24e430",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T04:24:25.854881Z",
     "start_time": "2023-06-18T04:24:21.330143Z"
    }
   },
   "outputs": [],
   "source": [
    "from datasets import load_metric\n",
    "accuracy_score = load_metric('accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bce5f48a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T04:24:27.140187Z",
     "start_time": "2023-06-18T04:24:27.118353Z"
    }
   },
   "outputs": [],
   "source": [
    "class PerformanceBenchmark:\n",
    "    def __init__(self, pipe, dataset, optim_type='BERT baseline'):\n",
    "        self.pipe = pipe\n",
    "        self.dataset = dataset\n",
    "        self.optim_type = optim_type\n",
    "        \n",
    "#     def compute_accuracy(self):\n",
    "#         pass\n",
    "    \n",
    "    def compute_accuracy(self):\n",
    "        preds, labels = [], []\n",
    "        # 可以改造为批次化的 input\n",
    "        for example in tqdm(self.dataset, desc='evaluate on test dataset'):\n",
    "            pred = self.pipe(example['text'])[0]['label']\n",
    "            label = example['intent']\n",
    "            preds.append(intents.str2int(pred))\n",
    "            labels.append(label)\n",
    "        accuracy = accuracy_score.compute(predictions=preds, references=labels)\n",
    "        print(f'Accuracy on test set: {accuracy[\"accuracy\"]:.3f}')\n",
    "        return accuracy\n",
    "    \n",
    "    def compute_size(self):\n",
    "        state_dict = self.pipe.model.state_dict()\n",
    "        tmp_path = Path('model.pth')\n",
    "        torch.save(state_dict, tmp_path)\n",
    "        size_mb = Path(tmp_path).stat().st_size / (1024*1024)\n",
    "        tmp_path.unlink()\n",
    "        print(f'Model size (MB): {size_mb:.2f}')\n",
    "        return {'size_mb': size_mb}\n",
    "    \n",
    "    def time_pipeline(self, query='what is the pin number of my account'):\n",
    "        latencies = []\n",
    "        \n",
    "        # warmup\n",
    "        for _ in range(10):\n",
    "            _ = self.pipe(query)\n",
    "            \n",
    "        # timed run\n",
    "        for _ in range(100):\n",
    "            start_time = perf_counter()\n",
    "            _ = self.pipe(query)\n",
    "            latency = perf_counter() - start_time\n",
    "            latencies.append(latency)\n",
    "        \n",
    "        # run stats\n",
    "        time_avg_time = 1000 * np.mean(latencies)\n",
    "        time_std_time = 1000 * np.std(latencies)\n",
    "        print(f'Average latency (ms): {time_avg_time:.2f} +\\- {time_std_time:.2f}')\n",
    "        return {'time_avg_ms': time_avg_time, 'time_std_ms': time_std_time}\n",
    "    \n",
    "    def run_benchmark(self):\n",
    "        metrics = {}\n",
    "        metrics[self.optim_type] = self.compute_size()\n",
    "        metrics[self.optim_type].update(self.time_pipeline())\n",
    "        metrics[self.optim_type].update(self.compute_accuracy())\n",
    "        return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37a184b8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T04:25:08.078619Z",
     "start_time": "2023-06-18T04:24:38.075869Z"
    }
   },
   "outputs": [],
   "source": [
    "benchmark = PerformanceBenchmark(pipe, clinc['test'])\n",
    "benchmark.run_benchmark()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
